package nsalt.func 

import chisel3._
import chisel3.util._
import chisel3.util.experimental.BoringUtils

import chisel3.experimental.ChiselEnum
// import chisel3.experimental.suppressEnumCastWarning

import nsalt._
import nsalt.bus._
import nsalt.util._

// States of the LoadStoreExec finite-state machine, in encoding order:
//   IDLE      - no access in flight; a new data-memory request may be issued.
//   WAIT_TLB  - request issued while the data TLB is enabled; waiting for the
//               TLB to report completion (or a page fault).
//   WAIT_RESP - waiting for the data-memory response to fire.
//   PART_LOAD - state entered for one cycle after the response of a partial
//               load; the FSM returns to IDLE unconditionally from here.
object ExecState extends ChiselEnum {
  val IDLE, WAIT_TLB, WAIT_RESP, PART_LOAD = Value
}

// Load/Store Unit, the execution part. Found at:
// https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/nutcore/backend/fu/UnpipelinedLSU.scala#L289
//
// LoadStoreExec takes one load/store operation from io.in (address in src1,
// operation code in oper, store data in io.dataW), issues a single request on
// io.dmem, and presents the (size-adjusted, sign/zero-extended) load data on
// io.out. Misaligned addresses are never sent to memory; they are reported on
// io.loadAddrMisaligned / io.storeAddrMisaligned and complete immediately.
//
// The low two bits of oper are used as the access-size encoding throughout:
//   00 = byte, 01 = half-word, 10 = word, 11 = double-word.
class LoadStoreExec extends Module with Config {

  val io = IO(new LoadStorePort)

  // NOTE: these aliases must be declared before any val that reads them.
  // Class-body vals are initialised in textual order, so a val that used
  // `oper` or `addr` above this point would see null during elaboration.
  val valid = io.in.valid
  val addr  = io.in.bits.src1
  val oper  = io.in.bits.oper

  // True when addr is naturally aligned for the access size in oper(1,0).
  val isAddrAligned = LookupTree(oper(1,0), List(
    "b00".U   -> true.B,              //b
    "b01".U   -> (addr(0) === 0.U),   //h
    "b10".U   -> (addr(1,0) === 0.U), //w
    "b11".U   -> (addr(2,0) === 0.U)  //d
  ))

  // Drives the unit's input signals from the given values and returns the
  // output data signal (io.out.bits).
  def access(valid: Bool, addr: UInt, oper: UInt): UInt = {
    this.valid := valid
    this.addr := addr
    this.oper := oper 
    io.out.bits
  }

  // Byte-enable mask for an 8-byte-wide bus: a run of 1/2/4/8 ones selected by
  // sizeEncode, shifted left by the byte offset addr(2,0).
  def genWmask(addr: UInt, sizeEncode: UInt): UInt = {
    LookupTree(sizeEncode, List(
      "b00".U -> 0x1.U, //0001 << addr(2:0)
      "b01".U -> 0x3.U, //0011
      "b10".U -> 0xf.U, //1111
      "b11".U -> 0xff.U //11111111
    )) << addr(2, 0)
  }

  // Write data for a 64-bit bus: the low byte/half/word of data is replicated
  // across all lanes, so the write mask alone decides which lanes are stored.
  def genWdata(data: UInt, sizeEncode: UInt): UInt = {
    LookupTree(sizeEncode, List(
      "b00".U -> Fill(8, data(7, 0)),
      "b01".U -> Fill(4, data(15, 0)),
      "b10".U -> Fill(2, data(31, 0)),
      "b11".U -> data
    ))
  }

  // 32-bit variant of genWmask: byte offset is addr(1,0) and there is no
  // double-word entry.
  def genWmask32(addr: UInt, sizeEncode: UInt): UInt = {
    LookupTree(sizeEncode, List(
      "b00".U -> 0x1.U, //0001 << addr(1:0)
      "b01".U -> 0x3.U, //0011
      "b10".U -> 0xf.U  //1111
    )) << addr(1, 0)
  }

  // 32-bit variant of genWdata: replicates the low byte/half across a 32-bit
  // word; a word access passes data through unchanged.
  def genWdata32(data: UInt, sizeEncode: UInt): UInt = {
    LookupTree(sizeEncode, List(
      "b00".U -> Fill(4, data(7, 0)),
      "b01".U -> Fill(2, data(15, 0)),
      "b10".U -> data
    ))
  }

  val dmem = io.dmem

  // Address delayed by one cycle; used to pick the byte lanes of the
  // (equally delayed) read data in the PART_LOAD state.
  // NOTE(review): this presumably relies on the caller holding addr stable
  // until the operation completes - confirm against the parent module.
  val addrPrev = RegNext(addr)
  
  val isStore = valid && LoadStoreOperType.isStore(oper)
  val isAMO   = valid && LoadStoreOperType.isAMO(oper)
  val isStoreAMO = isStore || isAMO

  // Misalignment is classified as a "store" fault for stores and AMOs, and as
  // a "load" fault for everything else.
  val addrMisalignedL = valid && !isAddrAligned && !isStoreAMO
  val addrMisalignedS = valid && !isAddrAligned &&  isStoreAMO

  // Any non-store operation other than `ld` needs lane selection and
  // extension of the read data, which takes one extra cycle (PART_LOAD).
  val isPartialLoad = !isStore && (oper =/= LoadStoreOperType.ld)

  val state = RegInit(ExecState.IDLE)

  // ==========================================================================
  // Note:
  // The data-mem TLB is connected in an unpleasant manner. Consider
  // redesigning the ports so that this unit connects to the TLB or to memory
  // directly.
  // 
  // We are not using the TLB in the very initial version, so the dtlb-related
  // signals below are tied to false and are currently unused.


  val dtlbFinish = WireInit(false.B)
  val dtlbPageFault = WireInit(false.B)
  val dtlbEnable = WireInit(false.B)
  // if (Settings.get("HasDTLB")) {
  //   BoringUtils.addSink(dtlbFinish, "DTLBFINISH")
  //   BoringUtils.addSink(dtlbPageFault, "DTLBPF")
  //   BoringUtils.addSink(dtlbEnable, "DTLBENABLE")
  // }

  io.dtlbPageFault := dtlbPageFault


  // ==========================================================================
  // LoadStoreExec FSM:
  // When idle, a fired data-mem request moves the FSM to waiting for the TLB
  // (if the TLB is enabled) or to waiting for the memory response. Since the
  // TLB is not used at the moment, it always goes to waiting for the response.
  //
  // A page fault reported by the TLB aborts the access and returns to IDLE.
  // After the response fires, partial loads spend one cycle in PART_LOAD;
  // all other operations return to IDLE directly.
  //

  switch (state) {
    is (ExecState.IDLE) { 
      when (dmem.req.fire) {
        when (dtlbEnable) {
          state := ExecState.WAIT_TLB 
        }.otherwise{
          state := ExecState.WAIT_RESP
        }
      }
    }
    is (ExecState.WAIT_TLB) {
      when (dtlbFinish) {
        when (dtlbPageFault) {
          state := ExecState.IDLE 
        }.otherwise {
          state := ExecState.WAIT_RESP
        }
      }
    }
    is (ExecState.WAIT_RESP) {
      when (dmem.res.fire) {
        when (isPartialLoad) {
          state := ExecState.PART_LOAD 
        }.otherwise {
          state := ExecState.IDLE 
        }
      }
    }
    is (ExecState.PART_LOAD) {
      state := ExecState.IDLE 
    }
  }

  val size = oper(1,0)

  // XLEN is an elaboration-time constant, so these are Scala-level selections
  // between the 32-bit and 64-bit datapaths, not hardware muxes.
  val reqAddr  = if (XLEN == 32) SignExt(addr, VIRT_MEM_ADDR_LEN) else addr(VIRT_MEM_ADDR_LEN - 1, 0)
  val reqDataW = if (XLEN == 32) genWdata32(io.dataW, size) else genWdata(io.dataW, size)
  val reqMaskW = if (XLEN == 32) genWmask32(addr, size) else genWmask(addr, size)

  dmem.req.bits.apply(
    addr = reqAddr, 
    size = size, 
    dataW = reqDataW,
    maskW = reqMaskW,
    command = Mux(isStore, BusCommand.WRITE, BusCommand.READ)
  )
  
  // Valid when:
  // 1) input valid
  // 2) currently idle
  // 3) neither load addr nor store addr misaligned.
  dmem.req.valid := valid && (state === ExecState.IDLE) && !addrMisalignedL && !addrMisalignedS

  // always ready to receive the response from dmem.
  dmem.res.ready := true.B

  // The output is valid immediately on a fault (page fault while an access is
  // in flight, or a misaligned address). Otherwise partial loads complete in
  // PART_LOAD, and all other operations complete when the response fires.
  io.out.valid := Mux( dtlbPageFault && state =/= ExecState.IDLE || addrMisalignedL || addrMisalignedS,
    true.B, 
    Mux(isPartialLoad,
      state === ExecState.PART_LOAD, 
      dmem.res.fire && (state === ExecState.WAIT_RESP)
    )
  )

  // io.in.ready := (state === ExecState.IDLE) || dtlbPageFault
  io.in.ready := state === ExecState.IDLE

  val dataRead = dmem.res.bits.dataR
  // Read data delayed by one cycle, consumed in the PART_LOAD state.
  val dataReadPrev = RegNext(dataRead)

  // Drop the bytes below the accessed byte offset so that the requested data
  // starts at bit 0 (64-bit bus).
  val dataReadSel64 = LookupTree(addrPrev(2, 0), List(
    "b000".U -> dataReadPrev(63, 0),
    "b001".U -> dataReadPrev(63, 8),
    "b010".U -> dataReadPrev(63, 16),
    "b011".U -> dataReadPrev(63, 24),
    "b100".U -> dataReadPrev(63, 32),
    "b101".U -> dataReadPrev(63, 40),
    "b110".U -> dataReadPrev(63, 48),
    "b111".U -> dataReadPrev(63, 56)
  ))

  // Same lane selection for a 32-bit bus.
  val dataReadSel32 = LookupTree(addrPrev(1, 0), List(
    "b00".U -> dataReadPrev(31, 0),
    "b01".U -> dataReadPrev(31, 8),
    "b10".U -> dataReadPrev(31, 16),
    "b11".U -> dataReadPrev(31, 24)
  ))

  val dataReadSel = if (XLEN == 32) dataReadSel32 else dataReadSel64 

  // Truncate to the access width and sign- or zero-extend back to XLEN.
  val dataReadPartialLoad = LookupTree(oper, List(
      LoadStoreOperType.lb   -> SignExt(dataReadSel(7, 0) , XLEN),
      LoadStoreOperType.lh   -> SignExt(dataReadSel(15, 0), XLEN),
      LoadStoreOperType.lw   -> SignExt(dataReadSel(31, 0), XLEN),
      LoadStoreOperType.lbu  -> ZeroExt(dataReadSel(7, 0) , XLEN),
      LoadStoreOperType.lhu  -> ZeroExt(dataReadSel(15, 0), XLEN),
      LoadStoreOperType.lwu  -> ZeroExt(dataReadSel(31, 0), XLEN)
  ))

  // Partial loads use the registered, extended data; everything else forwards
  // the response data of the current cycle.
  io.out.bits := Mux(isPartialLoad, dataReadPartialLoad, dataRead(XLEN-1,0))

  io.isMMIO := DontCare

  // BoringUtils.addSource(addr, "LSUADDR")

  io.loadAddrMisaligned  := addrMisalignedL 
  io.storeAddrMisaligned := addrMisalignedS 
}
